This time you'll delve into the heart (and other innards) of recurrent neural networks on a class of toy problems.
Struggling to find a name for a variable? Now imagine how hard it is to come up with a name for your son or daughter. Surely no human has real expertise in what makes a good child name, so let's train an RNN instead.
It's dangerous to go alone, take these:
In [ ]:
import sys
if 'google.colab' in sys.modules:
    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/99ae2a3dae648428edbfc41fd10ed688e5365161/week07_%5Brecap%5D_rnn/names -O names
In [ ]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [ ]:
import os
start_token = " "
with open("names") as f:
lines = f.read()[:-1].split('\n')
lines = [start_token + line for line in lines]
In [ ]:
print('n samples = ', len(lines))
for x in lines[::1000]:
    print(x)
In [ ]:
MAX_LENGTH = max(map(len, lines))
print("max length =", MAX_LENGTH)
plt.title('Sequence length distribution')
plt.hist(list(map(len, lines)),bins=25);
In [ ]:
# all unique characters go here
tokens = <all unique characters in the dataset>
tokens = list(tokens)

num_tokens = len(tokens)
print('num_tokens = ', num_tokens)

assert 50 < num_tokens < 60, "Names should contain between 50 and 60 unique tokens, depending on encoding"
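If you need a reference, here is one possible way to fill in the placeholder above (a sketch; sorting is optional but makes the id mapping reproducible):
In [ ]:
# collect every distinct character that appears in the names
tokens = sorted(set(''.join(lines)))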
In [ ]:
token_to_id = <dictionary of symbol -> its identifier (index in tokens list)>
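A possible one-liner for this mapping (a sketch):
In [ ]:
# map each token to its index in the tokens list
token_to_id = {token: idx for idx, token in enumerate(tokens)}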
In [ ]:
assert len(tokens) == len(token_to_id), "dictionaries must have same size"

for i in range(num_tokens):
    assert token_to_id[tokens[i]] == i, "token identifier must be its position in the tokens list"

print("Seems alright!")
In [ ]:
def to_matrix(lines, max_len=None, pad=token_to_id[' '], dtype='int32', batch_first=True):
    """Casts a list of names into an RNN-digestible integer matrix"""
    max_len = max_len or max(map(len, lines))
    lines_ix = np.zeros([len(lines), max_len], dtype) + pad

    for i in range(len(lines)):
        line_ix = [token_to_id[c] for c in lines[i]]
        lines_ix[i, :len(line_ix)] = line_ix

    if not batch_first:  # convert [batch, time] into [time, batch]
        lines_ix = np.transpose(lines_ix)

    return lines_ix
In [ ]:
# Example: cast a few names to a padded matrix (padding uses the space token)
print('\n'.join(lines[::2000]))
print(to_matrix(lines[::2000]))
We can implement a recurrent neural network as a repeated application of a dense layer to the current input $x_t$ and the previous RNN state $h_t$. This is exactly what we're going to do now.
Since we're training a language model, there should also be an embedding layer that turns character ids into vectors and an output layer that predicts the log-probabilities of the next character.
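Concretely, one step of the cell you are about to implement computes (a sketch of the concat-linear-tanh update, where $\mathrm{emb}(x_t)$ denotes the embedding lookup):

$$h_{t+1} = \tanh\big(W\,[\mathrm{emb}(x_t);\,h_t] + b\big), \qquad \log P(x_{t+1} \mid h_{t+1}) = \operatorname{log\_softmax}\big(W_{\text{out}}\,h_{t+1} + b_{\text{out}}\big)$$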
In [ ]:
import torch, torch.nn as nn
import torch.nn.functional as F

class CharRNNCell(nn.Module):
    """
    Implement the scheme above as a torch module
    """
    def __init__(self, num_tokens=len(tokens), embedding_size=16, rnn_num_units=64):
        super(self.__class__, self).__init__()
        self.num_units = rnn_num_units

        self.embedding = nn.Embedding(num_tokens, embedding_size)
        self.rnn_update = nn.Linear(embedding_size + rnn_num_units, rnn_num_units)
        self.rnn_to_logits = nn.Linear(rnn_num_units, num_tokens)

    def forward(self, x, h_prev):
        """
        This method computes h_next(x, h_prev) and log P(x_next | h_next).
        We'll call it repeatedly to produce the whole sequence.

        :param x: batch of character ids, int64[batch_size]
        :param h_prev: previous rnn hidden states, float32 matrix [batch, rnn_num_units]
        """
        # get vector embedding of x
        x_emb = self.embedding(x)

        # compute next hidden state using self.rnn_update
        # hint: use torch.cat(..., dim=...) for concatenation
        h_next = <YOUR CODE>

        h_next = torch.tanh(h_next)

        assert h_next.size() == h_prev.size()

        # compute logits for next character probs
        logits = <YOUR CODE>

        return h_next, F.log_softmax(logits, -1)

    def initial_state(self, batch_size):
        """ return rnn state before it processes first input (aka h0) """
        return torch.zeros(batch_size, self.num_units)
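If you get stuck, here is one possible way to fill in the two <YOUR CODE> placeholders inside forward (a sketch, assuming the concat-linear-tanh scheme described above):

    # concatenate the input embedding with the previous hidden state and apply the dense layer;
    # torch.tanh is applied right after this line in the forward method
    h_next = self.rnn_update(torch.cat([x_emb, h_prev], dim=-1))

    # project the (tanh-activated) hidden state onto the vocabulary to get next-character logits
    logits = self.rnn_to_logits(h_next)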
In [ ]:
char_rnn = CharRNNCell()
In [ ]:
def rnn_loop(char_rnn, batch_ix):
    """
    Computes log P(next_character) for all time-steps in batch_ix
    :param batch_ix: an int64 tensor of shape [batch, time], cast from the output of to_matrix(lines)
    """
    batch_size, max_length = batch_ix.size()
    hid_state = char_rnn.initial_state(batch_size)
    logprobs = []

    for x_t in batch_ix.transpose(0, 1):
        hid_state, logp_next = char_rnn(x_t, hid_state)  # <-- here we call your one-step code
        logprobs.append(logp_next)

    return torch.stack(logprobs, dim=1)
In [ ]:
batch_ix = to_matrix(lines[:5])
batch_ix = torch.tensor(batch_ix, dtype=torch.int64)
logp_seq = rnn_loop(char_rnn, batch_ix)
assert torch.max(logp_seq).data.numpy() <= 0
assert tuple(logp_seq.size()) == batch_ix.shape + (num_tokens,)
We can now train our neural network to minimize crossentropy (i.e. maximize log-likelihood) with the actual next tokens.
To do so in a vectorized manner, we take batch_ix[:, 1:]
- a matrix of token ids shifted one step to the left, so that its i-th element is actually the "next token" for the i-th prediction
In [ ]:
predictions_logp = logp_seq[:, :-1]
actual_next_tokens = batch_ix[:, 1:]
logp_next = torch.gather(predictions_logp, dim=2, index=actual_next_tokens[:,:,None])
loss = -logp_next.mean()
In [ ]:
loss.backward()
In [ ]:
for w in char_rnn.parameters():
    assert w.grad is not None and torch.max(torch.abs(w.grad)).data.numpy() != 0, \
        "Loss is not differentiable w.r.t. a weight with shape %s. Check forward method." % (w.size(),)
In [ ]:
from IPython.display import clear_output
from random import sample
char_rnn = CharRNNCell()
opt = torch.optim.Adam(char_rnn.parameters())
history = []
In [ ]:
for i in range(1000):
    batch_ix = to_matrix(sample(lines, 32), max_len=MAX_LENGTH)
    batch_ix = torch.tensor(batch_ix, dtype=torch.int64)

    logp_seq = rnn_loop(char_rnn, batch_ix)

    # compute loss
    <YOUR CODE>

    loss = <YOUR CODE>

    # train with backprop
    <YOUR CODE>

    history.append(loss.data.numpy())
    if (i + 1) % 100 == 0:
        clear_output(True)
        plt.plot(history, label='loss')
        plt.legend()
        plt.show()

assert np.mean(history[:10]) > np.mean(history[-10:]), "RNN didn't converge."
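One possible way to fill in the <YOUR CODE> placeholders inside the loop (a sketch that reuses the gather trick from the loss cell above):

    # compute loss: gather log-probabilities of the characters that actually came next
    predictions_logp = logp_seq[:, :-1]
    actual_next_tokens = batch_ix[:, 1:]
    loss = -torch.gather(predictions_logp, dim=2, index=actual_next_tokens[:, :, None]).mean()

    # train with backprop
    opt.zero_grad()
    loss.backward()
    opt.step()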
In [ ]:
def generate_sample(char_rnn, seed_phrase=' ', max_length=MAX_LENGTH, temperature=1.0):
    '''
    Generates text given a seed phrase that the RNN is asked to continue.
    :param seed_phrase: prefix characters. The RNN is asked to continue the phrase
    :param max_length: maximum output length, including seed_phrase
    :param temperature: coefficient for sampling. Higher temperature produces more chaotic outputs,
                        smaller temperature converges to the single most likely output
    '''
    x_sequence = [token_to_id[token] for token in seed_phrase]
    x_sequence = torch.tensor([x_sequence], dtype=torch.int64)
    hid_state = char_rnn.initial_state(batch_size=1)

    # feed the seed phrase, if any
    for i in range(len(seed_phrase) - 1):
        hid_state, _ = char_rnn(x_sequence[:, i], hid_state)

    # start generating
    for _ in range(max_length - len(seed_phrase)):
        hid_state, logp_next = char_rnn(x_sequence[:, -1], hid_state)
        p_next = F.softmax(logp_next / temperature, dim=-1).data.numpy()[0]

        # sample next token and push it back into x_sequence
        next_ix = np.random.choice(num_tokens, p=p_next)
        next_ix = torch.tensor([[next_ix]], dtype=torch.int64)
        x_sequence = torch.cat([x_sequence, next_ix], dim=1)

    return ''.join([tokens[ix] for ix in x_sequence.data.numpy()[0]])
In [ ]:
for _ in range(10):
    print(generate_sample(char_rnn))
In [ ]:
for _ in range(50):
    print(generate_sample(char_rnn, seed_phrase=' Trump'))
You've just implemented a recurrent language model that can be tasked with generating any kind of sequence, so there's plenty of data you can try it on. If you want to collect your own dataset from the web, scraping tools like Selenium or Scrapy can help with that. Good hunting!
What we just did is a manual, low-level implementation of an RNN. While it's cool, you probably won't like the idea of rewriting it from scratch on every occasion.
As you might have guessed, torch has a solution for this. To be more specific, there are two options:
- nn.RNNCell(emb_size, rnn_num_units) implements a single RNN step, just like you did: basically concat-linear-tanh.
- nn.RNN(emb_size, rnn_num_units) implements the whole rnn_loop for you.
There are also nn.LSTMCell vs nn.LSTM, nn.GRUCell vs nn.GRU, etc.
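For reference, here's a tiny shape check of the two APIs (a sketch with made-up batch and sequence sizes):
In [ ]:
import torch, torch.nn as nn

cell = nn.RNNCell(16, 64)                    # one step: concat-linear-tanh under the hood
x_emb = torch.randn(32, 16)                  # [batch, emb_size]
h = cell(x_emb, torch.zeros(32, 64))         # next hidden state, [batch, rnn_num_units]

rnn = nn.RNN(16, 64, batch_first=True)       # the whole loop in one call
out, h_last = rnn(torch.randn(32, 10, 16))   # out: [batch, time, units], h_last: [1, batch, units]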
In this example we'll rewrite char_rnn and rnn_loop using the high-level RNN API.
In [ ]:
class CharRNNLoop(nn.Module):
    def __init__(self, num_tokens=num_tokens, emb_size=16, rnn_num_units=64):
        super(self.__class__, self).__init__()
        self.emb = nn.Embedding(num_tokens, emb_size)
        self.rnn = nn.RNN(emb_size, rnn_num_units, batch_first=True)
        self.hid_to_logits = nn.Linear(rnn_num_units, num_tokens)

    def forward(self, x):
        h_seq, _ = self.rnn(self.emb(x))
        next_logits = self.hid_to_logits(h_seq)
        next_logp = F.log_softmax(next_logits, dim=-1)
        return next_logp

model = CharRNNLoop()
In [ ]:
# the model applies over the whole sequence
batch_ix = to_matrix(sample(lines, 32), max_len=MAX_LENGTH)
batch_ix = torch.tensor(batch_ix, dtype=torch.int64)
logp_seq = model(batch_ix)
# compute loss. This time we use nll_loss with some duct tape
loss = F.nll_loss(logp_seq[:, :-1].contiguous().view(-1, num_tokens),
                  batch_ix[:, 1:].contiguous().view(-1))
loss.backward()
Here's another example, this time using an LSTM cell:
In [ ]:
import torch, torch.nn as nn
import torch.nn.functional as F

class CharLSTMCell(nn.Module):
    """
    Implements something like CharRNNCell, but with an LSTM
    """
    def __init__(self, num_tokens=len(tokens), embedding_size=16, rnn_num_units=64):
        super(self.__class__, self).__init__()
        self.num_units = rnn_num_units

        self.emb = nn.Embedding(num_tokens, embedding_size)
        self.lstm = nn.LSTMCell(embedding_size, rnn_num_units)
        self.rnn_to_logits = nn.Linear(rnn_num_units, num_tokens)

    def forward(self, x, prev_state):
        (prev_h, prev_c) = prev_state
        (next_h, next_c) = self.lstm(self.emb(x), (prev_h, prev_c))
        logits = self.rnn_to_logits(next_h)
        return (next_h, next_c), F.log_softmax(logits, -1)

    def initial_state(self, batch_size):
        """ LSTM has two state variables: cell and hid """
        return torch.zeros(batch_size, self.num_units), torch.zeros(batch_size, self.num_units)

char_lstm = CharLSTMCell()
In [ ]:
# the model applies over the whole sequence
batch_ix = to_matrix(sample(lines, 32), max_len=MAX_LENGTH)
batch_ix = torch.tensor(batch_ix, dtype=torch.int64)
logp_seq = rnn_loop(char_lstm, batch_ix)
# compute loss. This time we use nll_loss with some duct tape
loss = F.nll_loss(logp_seq[:, :-1].contiguous().view(-1, num_tokens),
                  batch_ix[:, 1:].contiguous().view(-1))
loss.backward()
Bonus quest: implement a model that uses 2 LSTM layers (the second LSTM takes the hidden states of the first as its input) and train it on your data.
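One possible way to approach the bonus quest (a sketch; the layer sizes are arbitrary assumptions, and the class keeps the same forward/initial_state interface so it plugs straight into rnn_loop):
In [ ]:
class TwoLayerCharLSTMCell(nn.Module):
    """Two stacked LSTM cells: the second one consumes the hidden state of the first."""
    def __init__(self, num_tokens=len(tokens), embedding_size=16, rnn_num_units=64):
        super().__init__()
        self.num_units = rnn_num_units
        self.emb = nn.Embedding(num_tokens, embedding_size)
        self.lstm1 = nn.LSTMCell(embedding_size, rnn_num_units)
        self.lstm2 = nn.LSTMCell(rnn_num_units, rnn_num_units)
        self.rnn_to_logits = nn.Linear(rnn_num_units, num_tokens)

    def forward(self, x, prev_state):
        (h1, c1), (h2, c2) = prev_state
        h1, c1 = self.lstm1(self.emb(x), (h1, c1))
        h2, c2 = self.lstm2(h1, (h2, c2))  # the second layer reads the first layer's hidden state
        logits = self.rnn_to_logits(h2)
        return ((h1, c1), (h2, c2)), F.log_softmax(logits, -1)

    def initial_state(self, batch_size):
        zeros = lambda: torch.zeros(batch_size, self.num_units)
        return (zeros(), zeros()), (zeros(), zeros())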